{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# UnifyDTypeScale\n",
    "\n",
    "参考：`tvm/src/relay/quantize/realize.cc`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```c++\n",
    "float ChooseDomScale(const std::vector<const QRealizeIntExprNode*>& nptrs) {\n",
    "  if (nptrs.size() == 2) {\n",
    "    // x = a * s1, y = b * s2\n",
    "    // x + y = (a * s1 / s2 + b) * s2, if s1 > s2\n",
    "    //       = (a + b * s2 / s1) * s1, if s2 > s1\n",
    "    float s1 = GetScalarFromConstant<float>(nptrs[0]->dom_scale);\n",
    "    float s2 = GetScalarFromConstant<float>(nptrs[1]->dom_scale);\n",
    "    return s1 > s2 ? s2 : s1;\n",
    "  } else {\n",
    "    const QConfig& cfg = QConfig::Current();\n",
    "    float scale = cfg->global_scale;\n",
    "    return scale / std::pow(2.0, cfg->nbit_activation - 1);\n",
    "  }\n",
    "}\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这段代码定义了一个名为 `ChooseDomScale` 的函数，用于选择两个节点中较小的一个作为它们的共同量化比例。\n",
    "\n",
    "函数接收 `QRealizeIntExprNode` 指针的向量 `nptrs` 作为参数。如果向量的大小为 2，则根据两个节点的量化比例计算它们的和，并返回较小的那个比例。具体来说，如果 `s1 > s2`，则返回 `s2`；否则返回 `s1`。\n",
    "\n",
    "如果向量的大小不为 2，则获取当前的量化配置（`QConfig::Current()`），并返回全局比例除以 2 的 `cfg->nbit_activation - 1` 次方的结果。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "````{dropdown}\n",
    "```c++\n",
    "\n",
    "/* \\brief Unify the dom scale of arguments */\n",
    "Array<Expr> UnifyDTypeScale(const Array<Expr>& ref_args, const Array<Expr>& args,\n",
    "                            DataType* dtype_ptr, Expr* scale_ptr,\n",
    "                            DataType dtype = DataType::Void()) {\n",
    "  static const Op& simulated_quantize = Op::Get(\"relay.op.annotation.simulated_quantize\");\n",
    "  const QConfig& cfg = QConfig::Current();\n",
    "\n",
    "  std::vector<const QRealizeIntExprNode*> nptrs;\n",
    "  Array<Expr> ret;\n",
    "  for (auto arg : args) {\n",
    "    const auto* nptr = arg.as<QRealizeIntExprNode>();\n",
    "    ICHECK(nptr);\n",
    "    nptrs.push_back(nptr);\n",
    "    ret.push_back(nptr->data);\n",
    "  }\n",
    "\n",
    "  // unify the data type\n",
    "  ICHECK_EQ(ref_args.size(), args.size());\n",
    "\n",
    "  if (dtype.is_void()) {\n",
    "    if (ret.size() == 2 && nptrs[1]->dtype == cfg->dtype_input) {\n",
    "      dtype = cfg->dtype_input;\n",
    "    } else {\n",
    "      dtype = cfg->dtype_activation;\n",
    "    }\n",
    "  }\n",
    "\n",
    "  for (size_t i = 0; i < ret.size(); ++i) {\n",
    "    auto ref_arg = ref_args[i].as<CallNode>();\n",
    "    if (nptrs[i]->dtype != dtype) {\n",
    "      ret.Set(i, Cast(ret[i], dtype));\n",
    "    } else if (ref_arg && ref_arg->op.same_as(simulated_quantize) &&\n",
    "               ref_arg->attrs.as<SimulatedQuantizeAttrs>()->kind == kQInput) {\n",
    "      auto new_arg = Cast(ret[i], cfg->dtype_input);\n",
    "      new_arg = StopFusion(new_arg);\n",
    "      ret.Set(i, Cast(new_arg, dtype));\n",
    "    }\n",
    "  }\n",
    "\n",
    "  // unify the dom_scale\n",
    "  float s = ChooseDomScale(nptrs);\n",
    "  Expr dom_scale = MakeConstantScalar(DataType::Float(32), s);\n",
    "  for (size_t i = 0; i < ret.size(); ++i) {\n",
    "    float cur_s = GetScalarFromConstant<float>(nptrs[i]->dom_scale);\n",
    "    ret.Set(i, MulAndDiv(ret[i], cur_s, s, dtype, ref_args[i]->type_as<TensorTypeNode>()->shape));\n",
    "  }\n",
    "\n",
    "  *dtype_ptr = dtype;\n",
    "  *scale_ptr = dom_scale;\n",
    "  return ret;\n",
    "}\n",
    "```\n",
    "````"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这段代码定义了一个名为`UnifyDTypeScale`的函数，用于统一参数的数据类型和量化比例。\n",
    "\n",
    "函数接收5个参数：\n",
    "- `ref_args`：参考参数；\n",
    "- `args`：需要处理的参数；\n",
    "- `dtype_ptr`：指向数据类型的指针；\n",
    "- `scale_ptr`：指向量化比例的指针；\n",
    "- `dtype`：默认为空的数据类型。\n",
    "\n",
    "函数首先获取当前的量化配置（`QConfig::Current()`），然后遍历`args`中的每个参数，将其转换为`QRealizeIntExprNode`类型，并将其添加到`nptrs`向量中。同时，将每个参数的数据部分添加到`ret`数组中。\n",
    "\n",
    "接下来，函数检查是否需要统一数据类型。如果`dtype`为空，则根据`ret`的大小和`cfg->dtype_input`的值来设置`dtype`。否则，使用给定的`dtype`值。\n",
    "\n",
    "然后，函数遍历`ret`中的每个元素，并根据以下条件进行转换：\n",
    "- 如果当前参数的数据类型与`dtype`不同，则将其转换为`dtype`类型；\n",
    "- 如果当前参数是模拟量化节点且其属性为输入类型，则将其转换为`cfg->dtype_input`类型，并停止融合操作。\n",
    "\n",
    "最后，函数调用`ChooseDomScale`函数来选择两个节点中较小的一个作为它们的共同量化比例，并将结果存储在`dom_scale`变量中。接着，遍历`ret`中的每个元素，将其乘以当前比例和共同比例之间的比值，并将结果存储回`ret`数组中。\n",
    "\n",
    "最后，函数将`dtype`和`dom_scale`分别存储到`dtype_ptr`和`scale_ptr`指向的位置，并返回`ret`数组。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
